#include "plato.h"

// File-local (static linkage) allocator hooks consulted by get_malloc() and
// get_free(). They start out pointing at the C runtime's malloc/free and are
// replaced by set_mem_allocator().
static plato::plato_malloc_func malloc_func_{malloc};
static plato::plato_free_func free_func_{free};

// Installs the allocation hooks used by memory blocks created afterwards.
// A null hook is tolerated: get_malloc()/get_free() fall back to the C
// runtime's malloc/free in that case.
// NOTE(review): the hooks are plain globals with no synchronization visible
// here; presumably this is meant to be called once during start-up.
auto plato::set_mem_allocator(plato_malloc_func malloc_func,
                              plato_free_func free_func) -> void {
  free_func_ = free_func;
  malloc_func_ = malloc_func;
}

// Returns the currently installed allocation hook, or the C runtime's malloc
// when no (non-null) hook has been installed.
auto plato::get_malloc() -> plato_malloc_func {
  if (malloc_func_) {
    return malloc_func_;
  }
  return malloc; // fallback: default C allocator
}

// Returns the currently installed deallocation hook, or the C runtime's free
// when no (non-null) hook has been installed.
auto plato::get_free() -> plato_free_func {
  if (free_func_) {
    return free_func_;
  }
  return free; // fallback: default C deallocator
}

namespace plato {
class MemBlockImpl : public MemBlock {
  char *block_{nullptr};
  std::size_t pos_{0};
  std::size_t max_{0};

public:
  MemBlockImpl() {
    max_ = 1024 * 16;
    block_ = (char *)(get_malloc()(max_));
  }
  virtual ~MemBlockImpl() {
    if (block_) {
      get_free()(block_);
    }
  }
  virtual auto reset(std::size_t size) -> void override {
    if (block_) {
      get_free()(block_);
    }
    block_ = (char *)(get_malloc()(size));
    max_ = size;
    pos_ = 0;
  }
  virtual auto get(std::size_t size) -> void * override {
    if (pos_ + size >= max_) {
      return nullptr;
    }
    auto *m = (void *)(block_ + pos_);
    pos_ += size;
    return m;
  }
};
} // namespace plato

// Factory: creates a default-sized MemBlock and transfers ownership to the
// caller through the abstract interface type.
auto plato::new_mem_block() -> std::unique_ptr<MemBlock> {
  std::unique_ptr<MemBlock> block{new MemBlockImpl{}};
  return block;
}

// Factory returning an owning raw pointer. The object is created with `new`,
// so the caller is responsible for releasing it with `delete`.
auto plato::new_mem_block_raw() -> MemBlock * {
  MemBlock *block = new MemBlockImpl{};
  return block;
}
